{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import the required libraries\n",
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from tensorflow.keras import Model ,layers\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the training table from disk.\n",
    "data = pd.read_csv('train.csv')\n",
    "\n",
    "# Every column except 'label' is image data; dividing by 255.0 maps it to 0-1\n",
    "# (assuming 0-255 pixel values -- confirm against the CSV), then store as float32.\n",
    "x = (data.loc[:, data.columns != 'label'].values / 255.0).astype(np.float32)\n",
    "\n",
    "# The 'label' column holds the targets.\n",
    "y = data['label'].values\n",
    "\n",
    "# Wrap (features, labels) in a tf.data pipeline: repeat forever, shuffle with a\n",
    "# 1000-element buffer, emit batches of 128 and prefetch one batch ahead.\n",
    "data = tf.data.Dataset.from_tensor_slices((x, y))\n",
    "data_loader = data.repeat().shuffle(1000).batch(128).prefetch(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generator network: maps a batch of noise vectors to 28x28x1 images.\n",
    "class Generator(Model):\n",
    "    \"\"\"Noise-to-image network: one dense layer followed by two transposed convolutions.\n",
    "\n",
    "    The output passes through tanh so generated pixels lie in [-1, 1], the range\n",
    "    the training step maps real images to via ``real_image*2. - 1.`` (given\n",
    "    inputs in [0, 1]).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Dense projection that is later reshaped to a 7x7 map with 128 channels.\n",
    "        self.fc1 = layers.Dense(7*7*128)\n",
    "        self.bn1 = layers.BatchNormalization()\n",
    "        # Stride-2 transposed convolutions with SAME padding double the spatial size.\n",
    "        self.conv2tr1 = layers.Conv2DTranspose(64, 5, strides=2, padding='SAME')\n",
    "        self.bn2 = layers.BatchNormalization()\n",
    "        self.conv2tr2 = layers.Conv2DTranspose(1, 5, strides=2, padding='SAME')\n",
    "\n",
    "    def call(self, x, is_training=False):\n",
    "        \"\"\"Generate images from noise.\n",
    "\n",
    "        Args:\n",
    "            x: noise batch fed to the dense layer (presumably (batch, noise_dim)).\n",
    "            is_training: forwarded to the batch-norm layers as ``training``.\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape (batch, 28, 28, 1) with values in [-1, 1].\n",
    "        \"\"\"\n",
    "        x = tf.nn.leaky_relu(self.bn1(self.fc1(x), training=is_training))\n",
    "        # Dense output -> 7x7 feature map with 128 channels.\n",
    "        x = tf.reshape(x, shape=[-1, 7, 7, 128])\n",
    "        # 7x7 -> 14x14.\n",
    "        x = tf.nn.leaky_relu(self.bn2(self.conv2tr1(x), training=is_training))\n",
    "        # 14x14 -> 28x28, single channel.\n",
    "        x = self.conv2tr2(x)\n",
    "        # Bound the output so fake images share the value range of the rescaled real ones.\n",
    "        return tf.nn.tanh(x)\n",
    "# Discriminator network: scores 28x28 images with two class logits.\n",
    "class Discriminator(Model):\n",
    "    \"\"\"Convolutional classifier that emits two logits per input image.\n",
    "\n",
    "    Which logit stands for real and which for fake is decided by the loss\n",
    "    functions, not by this class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Two stride-2 convolutions, each paired with batch normalization.\n",
    "        self.conv1 = layers.Conv2D(64, 5, strides=2, padding='SAME')\n",
    "        self.bn1 = layers.BatchNormalization()\n",
    "        self.conv2 = layers.Conv2D(128, 5, strides=2, padding='SAME')\n",
    "        self.bn2 = layers.BatchNormalization()\n",
    "        # Dense head: 1024 hidden units, then 2 output logits.\n",
    "        self.flatten = layers.Flatten()\n",
    "        self.fc1 = layers.Dense(1024)\n",
    "        self.bn3 = layers.BatchNormalization()\n",
    "        self.fc2 = layers.Dense(2)\n",
    "\n",
    "    def call(self, x, is_training=False):\n",
    "        \"\"\"Classify a batch of images.\n",
    "\n",
    "        Args:\n",
    "            x: image batch; reshaped here to (-1, 28, 28, 1), so flat 784-value\n",
    "                rows and (batch, 28, 28, 1) tensors are both accepted.\n",
    "            is_training: forwarded to the batch-norm layers as ``training``.\n",
    "\n",
    "        Returns:\n",
    "            Logits tensor with 2 values per image.\n",
    "        \"\"\"\n",
    "        x = tf.reshape(x, [-1, 28, 28, 1])\n",
    "        # Each stage is conv -> batch norm -> leaky ReLU.\n",
    "        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):\n",
    "            x = tf.nn.leaky_relu(norm(conv(x), training=is_training))\n",
    "        hidden = self.fc1(self.flatten(x))\n",
    "        hidden = tf.nn.leaky_relu(self.bn3(hidden, training=is_training))\n",
    "        return self.fc2(hidden)\n",
    "\n",
    "# Build the two networks; run_optimizer reads them as module-level globals.\n",
    "generator = Generator()\n",
    "discriminator = Discriminator()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generator loss.\n",
    "def generator_loss(image):\n",
    "    \"\"\"Cross-entropy that rewards the generator when its fakes are classified as real.\n",
    "\n",
    "    Args:\n",
    "        image: discriminator logits for generated images, one row of class\n",
    "            logits per sample (the parameter name is kept for compatibility).\n",
    "\n",
    "    Returns:\n",
    "        Scalar mean loss against class 1, which discriminator_loss uses for real images.\n",
    "    \"\"\"\n",
    "    # Size the label vector from the logits so any batch size works.\n",
    "    real_labels = tf.ones(tf.shape(image)[:1], dtype=tf.int32)\n",
    "    return tf.reduce_mean(\n",
    "        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=image, labels=real_labels))\n",
    "\n",
    "\n",
    "# Discriminator loss.\n",
    "def discriminator_loss(fake_image, real_image):\n",
    "    \"\"\"Cross-entropy for telling real images (class 1) from generated ones (class 0).\n",
    "\n",
    "    Args:\n",
    "        fake_image: discriminator logits for generated images.\n",
    "        real_image: discriminator logits for real images.\n",
    "\n",
    "    Returns:\n",
    "        Scalar: mean loss on the fake batch plus mean loss on the real batch.\n",
    "    \"\"\"\n",
    "    # generator_loss pushes fakes towards class 1, so the discriminator must label\n",
    "    # fakes 0 and reals 1; otherwise both networks optimise the same target.\n",
    "    fake_labels = tf.zeros(tf.shape(fake_image)[:1], dtype=tf.int32)\n",
    "    real_labels = tf.ones(tf.shape(real_image)[:1], dtype=tf.int32)\n",
    "    fake_loss = tf.reduce_mean(\n",
    "        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=fake_image, labels=fake_labels))\n",
    "    real_loss = tf.reduce_mean(\n",
    "        tf.nn.sparse_softmax_cross_entropy_with_logits(logits=real_image, labels=real_labels))\n",
    "    return fake_loss + real_loss\n",
    "\n",
    "# Optimizers: a separate Adam instance per network, each with learning rate 0.0002.\n",
    "optimizer_gen = tf.optimizers.Adam(learning_rate=0.0002)\n",
    "optimizer_disc =  tf.optimizers.Adam(learning_rate=0.0002)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_optimizer(real_image):\n",
    "    \"\"\"Run one discriminator update followed by one generator update.\n",
    "\n",
    "    Args:\n",
    "        real_image: batch of real images; presumably scaled to [0, 1], since it\n",
    "            is mapped to [-1, 1] here with ``real_image*2. - 1.``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(gen_loss, disc_loss)`` holding the losses computed in this step.\n",
    "    \"\"\"\n",
    "    noise_dim = 100\n",
    "    # Match the noise batch to the incoming batch instead of a hard-coded 128.\n",
    "    batch_size = real_image.shape[0]\n",
    "    real_images = real_image*2. - 1.\n",
    "\n",
    "    # --- Discriminator step ---\n",
    "    # NOTE(review): normal(-1., 1.) draws with mean -1 and std 1; possibly\n",
    "    # uniform(-1, 1) or normal(0, 1) was intended. Behaviour kept as is.\n",
    "    noise = np.random.normal(-1., 1., size=[batch_size, noise_dim]).astype(np.float32)\n",
    "    with tf.GradientTape() as disc_tape:\n",
    "        fake_images = generator(noise, is_training=True)              # generate fakes\n",
    "        disc_fake = discriminator(fake_images, is_training=True)      # score fakes\n",
    "        disc_real = discriminator(real_images, is_training=True)      # score reals\n",
    "        disc_loss = discriminator_loss(disc_fake, disc_real)\n",
    "    # Only the discriminator's variables are updated in this step.\n",
    "    gradient_disc = disc_tape.gradient(disc_loss, discriminator.trainable_variables)\n",
    "    optimizer_disc.apply_gradients(zip(gradient_disc, discriminator.trainable_variables))\n",
    "\n",
    "    # --- Generator step (fresh noise) ---\n",
    "    noise = np.random.normal(-1., 1., size=[batch_size, noise_dim]).astype(np.float32)\n",
    "    with tf.GradientTape() as gen_tape:\n",
    "        fake_images = generator(noise, is_training=True)\n",
    "        disc_fake = discriminator(fake_images, is_training=True)\n",
    "        gen_loss = generator_loss(disc_fake)\n",
    "    # Only the generator's variables are updated in this step.\n",
    "    gradient_gen = gen_tape.gradient(gen_loss, generator.trainable_variables)\n",
    "    optimizer_gen.apply_gradients(zip(gradient_gen, generator.trainable_variables))\n",
    "\n",
    "    return (gen_loss, disc_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training loop: consume 1000 batches and log both losses on every 10th step.\n",
    "for step, (batch_x, _) in enumerate(data_loader.take(1000)):\n",
    "    gen_loss, disc_loss = run_optimizer(batch_x)\n",
    "\n",
    "    # Skip logging unless the step index is a multiple of 10.\n",
    "    if step % 10:\n",
    "        continue\n",
    "    print(\"step: %i, gen_loss: %f, disc_loss: %f\" % (step, gen_loss, disc_loss))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
